
import numpy as np
import itertools
from scipy import linalg
from enum import Enum

from scipy.stats import entropy





class Constants(Enum):
    """Numeric constants used by the solvers.

    EPSILON is the lower bound on the previous residual in the
    natural-gradient convergence test (a non-decreasing residual only counts
    as convergence once the previous residual is at least EPSILON).
    """

    EPSILON = 1.0e-100


class NotFittedError(ValueError, AttributeError):
    """Raised when an estimator is used before it has been fitted.

    Deriving from both ValueError and AttributeError lets callers catch it
    with either exception type, which keeps older handlers working.
    """


#class LegendreDecomposition(TransformerMixin, BaseEstimator):
class LegendreDecomposition:
    """Legendre decomposition of a non-negative tensor.

    The input tensor ``P`` is normalized to sum to one and approximated by

        Q[i_1, ..., i_d] = exp(sum(theta[:i_1 + 1, ..., :i_d + 1])) / psi,

    where ``theta`` starts at zero and only its entries on the coordinates
    ``beta`` (passed to :meth:`fit_transform`) are updated.  ``theta`` is
    fitted so that the upper-set sums ``eta[i] = Q[i_1:, ..., i_d:].sum()``
    match those of ``P`` (``eta_hat``) on ``beta``.

    Parameters
    ----------
    core_size : int, default=2
        Number of core coordinates picked per first-axis index by
        ``_gen_core``; also the grid size used by ``_gen_norm`` for
        3rd-order tensors.
    solver : {'ng', 'gd'}, default='ng'
        ``'ng'``: natural gradient.  ``'gd'``: coordinate-wise gradient
        descent.  With ``'ng'``, ``_gen_basis`` also includes the
        ``_gen_norm`` coordinates.
    tol : float, default=1e-4
        Residual threshold for early stopping (``'ng'`` only).
    max_iter : int, default=10
        Maximum number of iterations.
    learning_rate : float, default=0.1
        Step size of the ``'gd'`` solver.
    random_state : int or None, default=None
        Seed used by ``_gen_core`` when ``shuffle`` is true.
    shuffle : bool, default=False
        If true, ``_gen_core`` picks core coordinates in random order,
        otherwise in ascending order of the normalized ``P``.
    verbose : int, default=0
        If non-zero, print diagnostic output.  NOTE: this also changes
        numpy's global print options (``threshold=200``).
    ng_step_size : float, default=0.07
        Step size applied to the natural-gradient direction (``'ng'`` only).
    """

    def __init__(self, core_size=2, solver='ng',
                 tol=1e-4, max_iter=10, learning_rate=0.1,
                 random_state=None, shuffle=False, verbose=0,
                 ng_step_size=0.07):
        self.core_size = core_size
        self.solver = solver
        self.tol = tol
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.random_state = random_state
        self.shuffle = shuffle
        self.verbose = verbose
        self.ng_step_size = ng_step_size
        if self.verbose:
            # Global side effect: affects every numpy array printed afterwards.
            np.set_printoptions(threshold=200)

    def fit_transform(self, P, coordinates):
        """Fit the model to ``P`` on ``coordinates`` and return the reconstruction.

        Parameters
        ----------
        P : array_like
            Non-negative input tensor.
        coordinates : iterable of int sequences
            Multi-indices ``beta`` whose ``theta`` entries are fitted.

        Returns
        -------
        numpy.ndarray
            Reconstructed tensor ``Q`` rescaled to the total mass of ``P``.

        Sets ``theta``, ``reconstruction_err_`` (relative Frobenius error)
        and ``kl_div`` (KL(P || Q) between the normalized tensors), plus the
        attributes set during fitting (``shape``, ``P``, ``beta``, ...).
        """
        P = np.asarray(P, dtype=float)
        self.theta = self._legendre_decomposition(P, coordinates)
        Q = self._compute_Q(self.theta, self.beta) * P.sum()
        self.reconstruction_err_ = self._calc_rmse(P, Q)

        P_flat = P.flatten()
        Q_flat = Q.flatten()
        P_flat = P_flat / np.sum(P_flat)
        Q_flat = Q_flat / np.sum(Q_flat)
        self.kl_div = entropy(P_flat, Q_flat, base=np.e)
        return Q

    def fit(self, P, y=None, **params):
        """Fit the model and return ``self``.

        ``y`` is ignored.  Keyword arguments (e.g. ``coordinates=...``) are
        forwarded to :meth:`fit_transform`.
        """
        self.fit_transform(P, **params)
        return self

    def transform(self):
        """Return the fitted distribution ``Q``.

        Unlike :meth:`fit_transform`, the result is normalized to sum to one
        rather than rescaled to the mass of the input.

        Raises NotFittedError if the model has not been fitted.
        """
        self._check_is_fitted(self)
        return self._compute_Q(self.theta, self.beta)

    def _compute_Q_(self, theta, beta):
        """Duplicate of :meth:`_compute_Q`, kept for backward compatibility."""
        return self._compute_Q(theta, beta)

    def _compute_eta_(self, Q):
        """Duplicate of :meth:`_compute_eta`, kept for backward compatibility."""
        return self._compute_eta(Q)

    def _compute_eta(self, Q):
        """Return ``eta`` with ``eta[i_1, ..., i_d] = Q[i_1:, ..., i_d:].sum()``.

        Computed as a reverse cumulative sum along each axis, i.e.
        O(size * ndim) instead of one slice sum per entry (O(size ** 2)).
        Always returns a new float array.
        """
        eta = np.array(Q, dtype=float)
        for axis in range(eta.ndim):
            eta = np.flip(np.cumsum(np.flip(eta, axis=axis), axis=axis), axis=axis)
        return eta

    def _compute_jacobian(self, eta, beta):
        """Return the Fisher information matrix ``G`` restricted to ``beta``.

        ``G[u, v] = eta[max(beta[u], beta[v])] - eta[beta[u]] * eta[beta[v]]``
        with ``max`` taken element-wise over the multi-indices.  An empty
        ``beta`` yields a ``(0, 0)`` matrix.

        Raises
        ------
        ValueError
            If ``eta.ndim`` differs from the length of the multi-indices.
        """
        beta = np.asarray(beta, dtype=int)
        size = len(beta)
        if size == 0:
            return np.zeros((0, 0))
        n_dims = beta.shape[1]

        if eta.ndim != n_dims:
            raise ValueError(f"Eta must be a {n_dims}-dimensional array, but got {eta.ndim}-dimensional.")

        # Pairwise element-wise maximum of the multi-indices: (size, size, n_dims).
        max_index = np.maximum(beta[:, None, :], beta[None, :, :])
        eta_max = eta[tuple(np.moveaxis(max_index, -1, 0))]
        eta_values = eta[tuple(beta.T)]
        return eta_max - np.outer(eta_values, eta_values)

    def _compute_residual(self, eta, beta):
        """Return the RMS difference between ``eta`` and ``self.eta_hat`` on ``beta``."""
        diffs = [eta[v] - self.eta_hat[v] for v in beta]
        return np.sqrt(np.mean(np.square(diffs)))

    def _calc_rmse(self, P, Q):
        """Return the relative Frobenius error ``||P - Q|| / ||P||``.

        Despite the name this is not the RMSE; the plain RMSE is printed
        when ``verbose`` is set.
        """
        diff = P - Q
        if self.verbose:
            print('rmse=', np.sqrt(np.mean(np.square(diff))))
        return np.sqrt(np.sum(diff ** 2)) / np.sqrt(np.sum(P ** 2))

    def _check_is_fitted(self, estimator, attributes=None, msg=None, all_or_any=all):
        """Raise NotFittedError unless ``estimator`` looks fitted.

        With ``attributes`` given, ``all_or_any`` of them must exist on
        ``estimator``.  Otherwise at least one instance attribute whose name
        ends in ``_`` (and does not start with ``__``) must exist.

        Raises
        ------
        TypeError
            If ``estimator`` has no ``fit`` attribute.
        NotFittedError
            If the check fails.
        """
        if msg is None:
            msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
                   "appropriate arguments before using this estimator.")

        if not hasattr(estimator, 'fit'):
            raise TypeError("%s is not an estimator instance." % (estimator))

        if attributes is not None:
            if not isinstance(attributes, (list, tuple)):
                attributes = [attributes]
            attrs = all_or_any([hasattr(estimator, attr) for attr in attributes])
        else:
            attrs = [v for v in vars(estimator)
                     if v.endswith("_") and not v.startswith("__")]

        if not attrs:
            raise NotFittedError(msg % {'name': type(estimator).__name__})

    def _normalizer(self, P):
        """Return ``P`` scaled to sum to one.

        Raises
        ------
        ValueError
            If ``P`` contains NaN or sums to zero (the result would be undefined).
        """
        if np.isnan(np.asarray(P, dtype=float)).any():
            raise ValueError("Input tensor contains NaN values.")
        total = np.sum(P)
        if total == 0:
            raise ValueError("Input tensor sums to zero and cannot be normalized.")
        return P / total

    def _initialize(self):
        """Return the initial ``theta``: zeros with the input tensor's shape."""
        return np.zeros(self.shape)

    def _grid_indices(self, n):
        """Return up to ``core_size`` evenly spaced indices from ``range(n)``."""
        if self.core_size < n:
            return [int(c * np.floor(n / self.core_size)) for c in range(self.core_size)]
        return list(range(n))

    def _claim(self, candidates):
        """Return the candidates not yet in ``basis_index`` (in order) and mark them."""
        claimed = []
        for v in candidates:
            if self.basis_index[v] == 0:
                claimed.append(v)
                self.basis_index[v] = 1
        return claimed

    def _gen_norm(self, shape):
        """Generate the "norm" basis coordinates for a 2nd- or 3rd-order tensor.

        B_1: every coordinate lying on a single axis (other indices 0).
        B_2 (3rd order only): for ``i`` on the ``_grid_indices`` grid of axis 0,
        the coordinates ``(i, j, 0)`` and ``(i, 0, k)`` with ``j``/``k`` on the
        grids of axes 1/2.

        Only coordinates still unmarked in ``self.basis_index`` are returned,
        and they are marked so they are not selected again.

        Raises NotImplementedError unless the tensor has order 2 or 3.
        """
        order = len(shape)
        if order not in (2, 3):
            raise NotImplementedError("Order of input tensor should be 2 or 3. Order: {}.".format(order))

        temp_beta = []
        # B_1
        for axis, n in enumerate(shape):
            for i in range(n):
                v = tuple(i if d == axis else 0 for d in range(order))
                if self.basis_index[v] == 0:
                    temp_beta.append(v)

        if order == 3:
            # B_2
            index_0, index_1, index_2 = (self._grid_indices(n) for n in shape)
            for i in index_0:
                for j in index_1:
                    if self.basis_index[i, j, 0] == 0:
                        temp_beta.append((i, j, 0))
                for k in index_2:
                    if self.basis_index[i, 0, k] == 0:
                        temp_beta.append((i, 0, k))

        return self._claim(temp_beta)

    def _get_P_value(self, v):
        """Return ``self.P[v]``; the sort key used by ``_gen_core``."""
        return self.P[v]

    def _gen_core(self, shape):
        """Generate the core (B_3) basis coordinates.

        For each first-axis index ``i``, the unmarked coordinates
        ``(i, ...)`` are ordered randomly (``shuffle``) or by ascending
        ``self.P`` value, and the first ``core_size`` of them are claimed.

        Raises NotImplementedError unless the tensor has order 2 or 3.
        """
        order = len(shape)
        if order not in (2, 3):
            raise NotImplementedError("Order of input tensor should be 2 or 3. Order: {}.".format(order))

        beta = []
        # B_3
        for i in range(shape[0]):
            tails = itertools.product(*(range(n) for n in shape[1:]))
            temp_beta = [(i,) + tail for tail in tails
                         if self.basis_index[(i,) + tail] == 0]

            if self.shuffle:
                # A fresh RandomState(seed) yields the same permutation as
                # seeding the global RNG, without mutating global state.
                np.random.RandomState(self.random_state).shuffle(temp_beta)
            else:
                temp_beta.sort(key=self._get_P_value)
            if self.verbose:
                print('temp_beta=', temp_beta)

            beta += self._claim(temp_beta[:self.core_size])

        return beta

    def _gen_basis(self, shape):
        """Build ``self.basis_index`` and return the basis coordinates ``beta``.

        The all-zero coordinate is excluded: ``theta[0, ..., 0]`` enters every
        prefix sum in ``_compute_Q`` and therefore cancels on normalization.
        """
        self.basis_index = np.zeros(shape)
        self.basis_index[(0,) * len(shape)] = 1

        beta = []
        if self.solver == 'ng':
            beta += self._gen_norm(shape)
        beta += self._gen_core(shape)
        return beta

    def _fit_gradient_descent(self, P, beta):
        """Fit ``theta`` by coordinate-wise gradient descent on ``beta``.

        For each coordinate ``v`` in turn, ``theta[v] -= learning_rate *
        (eta[v] - eta_hat[v])`` with ``eta`` recomputed from the current
        ``theta``.  Runs exactly ``max_iter`` sweeps (no early stopping).
        Sets ``self.eta_hat`` and ``self.res``; returns ``theta``.
        """
        theta = self._initialize()
        self.eta_hat = self._compute_eta(P)

        self.res = 0.
        if self.verbose:
            print("\n\n============= theta =============")
            print(theta)
            print("\n\n============= eta_hat =============")
            print(self.eta_hat)

        for n_iter in range(self.max_iter):
            eta = self._compute_eta(self._compute_Q(theta, beta))
            if self.verbose:
                print("\n\n============= iteration: {}, eta =============".format(n_iter))
                print(eta)

            self.res = self._compute_residual(eta, beta)
            if self.verbose:
                print("n_iter: {}, Residual: {}".format(n_iter, self.res))

            for v in beta:
                # theta_v <- theta_v - epsilon * (eta_v - eta_hat_v), using the
                # eta of the already-updated theta (sequential updates).
                eta = self._compute_eta(self._compute_Q(theta, beta))
                theta[v] -= self.learning_rate * (eta[v] - self.eta_hat[v])

            if self.verbose:
                print("\n\n============= iteration: {}, theta =============".format(n_iter))
                print(theta)

        return theta

    @staticmethod
    def _natural_gradient_direction(G, eta_vec):
        """Return ``G^{-1} @ eta_vec``, using the pseudo-inverse if ``G`` is singular."""
        try:
            return np.linalg.solve(G, eta_vec)
        except np.linalg.LinAlgError:
            return np.linalg.pinv(G) @ eta_vec

    def _fit_natural_gradient(self, P, beta):
        """Fit ``theta`` by natural gradient descent on the coordinates ``beta``.

        Each iteration updates ``theta[beta] -= ng_step_size * G^{-1} @
        (eta - eta_hat)[beta]``, with ``G`` from :meth:`_compute_jacobian`.
        Stops early when the residual drops to ``tol`` or stops decreasing.

        Sets ``self.eta_hat``, ``self.res``, ``self.kl_div`` (KL(P || Q) of the
        last evaluated iterate) and, on early stop, ``self.converged_n_iter``.
        Returns ``theta``.
        """
        theta = self._initialize()
        self.eta_hat = self._compute_eta(P)
        # Normalized P is loop-invariant.
        P_flat = P.flatten()
        P_flat = P_flat / np.sum(P_flat)

        self.res = 0.
        theta_vec = np.array([theta[v] for v in beta])
        if self.verbose:
            print("\n\n============= theta =============")
            print(theta)
            print("\n\n============= eta_hat =============")
            print(self.eta_hat)

        for n_iter in range(self.max_iter):
            Q = self._compute_Q(theta, beta)
            eta = self._compute_eta(Q)

            Q_flat = Q.flatten()
            Q_flat = Q_flat / np.sum(Q_flat)
            self.kl_div = entropy(P_flat, Q_flat, base=np.e)

            if self.verbose:
                print("\n\n============= iteration: {}, eta =============".format(n_iter))
                print(eta)

            prev_res = self.res
            self.res = self._compute_residual(eta, beta)
            if self.verbose:
                print("n_iter: {}, Residual: {}".format(n_iter, self.res))

            # Converged when the residual is small enough, or when it stopped
            # decreasing (not checked on the first iteration, where prev_res == 0).
            if (self.res <= self.tol) or (prev_res <= self.res and Constants.EPSILON.value <= prev_res):
                self.converged_n_iter = n_iter
                if self.verbose:
                    print("Convergence of theta at iteration: {}".format(self.converged_n_iter))
                break

            # Compute delta eta and the Fisher information matrix.
            delta_eta = eta - self.eta_hat
            eta_vec = np.array([delta_eta[v] for v in beta])
            G = self._compute_jacobian(eta, beta)
            if self.verbose:
                print("\n\n============= iteration: {}, delta_eta =============".format(n_iter))
                print(delta_eta)
                print("\n\n============= iteration: {}, eta_vec =============".format(n_iter))
                print(eta_vec)
                print("\n\n============= iteration: {}, G =============".format(n_iter))
                print(G)

            # TODO: Algorithm 7, Information Geometric Approaches for
            # Neural Network Algorithms to compute G inverse.
            # The same step size applies whether or not G is invertible.
            theta_vec -= self.ng_step_size * self._natural_gradient_direction(G, eta_vec)

            if self.verbose:
                try:
                    G_inv = np.linalg.inv(G)
                except np.linalg.LinAlgError:
                    G_inv = np.linalg.pinv(G)
                print("\n\n============= iteration: {}, G_inverse =============".format(n_iter))
                print(G_inv)
                print("\n\n============= iteration: {}, theta_vec =============".format(n_iter))
                print(theta_vec)

            for v, value in zip(beta, theta_vec):
                theta[v] = value

            if self.verbose:
                print("\n\n============= iteration: {}, theta =============".format(n_iter))
                print(theta)
        return theta

    @staticmethod
    def _as_index(v):
        """Return coordinate ``v`` as a tuple of Python ints (a numpy multi-index).

        Lists or ndarray rows would otherwise trigger numpy fancy indexing
        in expressions such as ``theta[v]``.

        Raises TypeError for non-integer coordinates.
        """
        index = np.atleast_1d(v)
        if not np.issubdtype(index.dtype, np.integer):
            raise TypeError("Coordinates must be integer multi-indices, got {!r}.".format(v))
        return tuple(index.tolist())

    def _legendre_decomposition(self, P, coordinates):
        """Normalize ``P`` and fit ``theta`` on ``coordinates`` with the chosen solver.

        Sets ``self.shape``, ``self.P`` (normalized input) and ``self.beta``.

        Raises ValueError for an unknown ``solver``.
        """
        self.shape = P.shape
        self.P = self._normalizer(P)
        self.beta = [self._as_index(v) for v in coordinates]
        if self.verbose:
            print("\n\n============= beta =============")
            print(self.beta)

        if self.solver == 'ng':
            theta = self._fit_natural_gradient(self.P, self.beta)
        elif self.solver == 'gd':
            theta = self._fit_gradient_descent(self.P, self.beta)
        else:
            raise ValueError("Invalid solver {}.".format(self.solver))

        return theta

    def _compute_Q(self, theta, beta):
        """Return the normalized distribution ``Q`` for the parameters ``theta``.

        ``log Q[i_1, ..., i_d] = sum(theta[:i_1 + 1, ..., :i_d + 1]) - log psi``,
        computed with one cumulative sum per axis.  The maximum is subtracted
        before ``exp`` so large ``theta`` does not overflow; this cancels on
        normalization.  ``beta`` is unused.
        """
        theta_sum = np.array(theta, dtype=float)
        for axis in range(theta_sum.ndim):
            theta_sum = np.cumsum(theta_sum, axis=axis)
        if theta_sum.size:
            theta_sum -= theta_sum.max()

        Q = np.exp(theta_sum)
        psi = Q.sum()
        Q /= psi
        return Q


